import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# read input data from MNIST_data/ and enable one_hot option
# Load the MNIST dataset from MNIST_data/ (downloaded on first run);
# one_hot=True encodes each label as a 10-element one-hot vector.
mnist = input_data.read_data_sets("MNIST_data/",one_hot = True)

# Interactive session so .eval()/.run() work without an explicit session argument.
sess = tf.InteractiveSession()

# Hyperparameters.
learning_rate = 1e-4
training_iters = 200000  # NOTE(review): unused — the training loop below hard-codes 34000 steps
batch_size = 50

# Placeholders: x holds flattened 28x28 images, y_ the one-hot labels,
# keep_prob the dropout keep probability (fed at run time).
x = tf.placeholder("float",shape = [None,784])
y_ = tf.placeholder("float",shape = [None,10])
keep_prob = tf.placeholder("float")

def conv2d(name, x, W, b):
    """Convolve x with W (stride 1, SAME padding), add bias b, and apply ReLU."""
    conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(tf.nn.bias_add(conv, b), name=name)

def max_pool2x2(name, x):
    """Halve both spatial dimensions of x via 2x2 max pooling (SAME padding)."""
    return tf.nn.max_pool(
        x,
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding='SAME',
        name=name,
    )

def norm(name, x):
    """Apply local response normalization with depth radius 4 to x.

    Currently unused in this file: the norm/dropout lines in alex_net are
    commented out.
    """
    return tf.nn.lrn(x, 4, bias=1.0, alpha=0.01 / 9.0, beta=0.75, name=name)

def alex_net(x, Weights, Biases, Dropout):
    """Build the 3-conv / 2-fc classification network and return raw logits.

    Args:
        x: batch of flattened MNIST images, shape [None, 784].
        Weights: dict of tf.Variables, keys 'wc1'..'wc3', 'wd1', 'wd2', 'out'.
        Biases: dict of tf.Variables, keys 'bc1'..'bc3', 'bd1', 'bd2', 'out'.
        Dropout: scalar keep probability for the fully connected layers.

    Returns:
        Logits tensor of shape [None, 10]; no softmax is applied here
        (the loss uses softmax_cross_entropy_with_logits).
    """
    # Reshape flat pixels back into NHWC image format.
    x = tf.reshape(x, shape=[-1, 28, 28, 1])

    # Convolution layer 1: 28x28x1 -> 14x14x64 after pooling.
    conv1 = conv2d('conv1', x, Weights['wc1'], Biases['bc1'])
    pool1 = max_pool2x2('pool1', conv1)

    # Convolution layer 2: 14x14x64 -> 7x7x128 after pooling.
    conv2 = conv2d('conv2', pool1, Weights['wc2'], Biases['bc2'])
    pool2 = max_pool2x2('pool2', conv2)

    # Convolution layer 3: 7x7x128 -> 4x4x256 after pooling (SAME padding rounds 7/2 up to 4).
    conv3 = conv2d('conv3', pool2, Weights['wc3'], Biases['bc3'])
    pool3 = max_pool2x2('pool3', conv3)

    # Fully connected layer 1: flatten 4*4*256 features -> 1024.
    dense1 = tf.reshape(pool3, [-1, Weights['wd1'].get_shape().as_list()[0]])
    dense1 = tf.nn.relu(tf.matmul(dense1, Weights['wd1']) + Biases['bd1'], name='fc1')
    # Fix: Dropout was previously accepted but never used (the dropout lines were
    # commented out), so the keep_prob values fed at train/test time had no effect.
    dense1 = tf.nn.dropout(dense1, Dropout)

    # Fully connected layer 2: 1024 -> 1024.
    dense2 = tf.nn.relu(tf.matmul(dense1, Weights['wd2']) + Biases['bd2'], name='fc2')
    dense2 = tf.nn.dropout(dense2, Dropout)

    # Output layer: 1024 -> 10 class logits.
    out = tf.matmul(dense2, Weights['out']) + Biases['out']

    return out

# Network parameters. Convolution kernels are 3x3; channel widths grow
# 1 -> 64 -> 128 -> 256 across the three conv layers, and the flattened
# 4*4*256 conv output feeds two 1024-unit fully connected layers.
weights = {
    'wc1': tf.Variable(tf.truncated_normal([3, 3, 1, 64], stddev=0.1)),
    'wc2': tf.Variable(tf.truncated_normal([3, 3, 64, 128], stddev=0.1)),
    'wc3': tf.Variable(tf.truncated_normal([3, 3, 128, 256], stddev=0.1)),
    'wd1': tf.Variable(tf.truncated_normal([4 * 4 * 256, 1024], stddev=0.1)),
    'wd2': tf.Variable(tf.truncated_normal([1024, 1024], stddev=0.1)),
    'out': tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
}
# One bias vector per layer, matched to that layer's output width.
biases = {
    'bc1': tf.Variable(tf.random_normal([64])),
    'bc2': tf.Variable(tf.random_normal([128])),
    'bc3': tf.Variable(tf.random_normal([256])),
    'bd1': tf.Variable(tf.random_normal([1024])),
    'bd2': tf.Variable(tf.random_normal([1024])),
    'out': tf.Variable(tf.random_normal([10]))
}

# Build the network; pred holds the raw class logits.
pred = alex_net(x, weights, biases, keep_prob)

# Loss: mean softmax cross-entropy between logits and one-hot labels.
# (Numerically stabler than the hand-rolled variant kept below for reference.)
# cost = -tf.reduce_sum(y_*tf.log(pred))
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=pred))

# Optimizer: Adam, minimizing the cross-entropy loss.
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

# Accuracy: fraction of samples whose argmax prediction matches the label argmax.
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Saver for checkpointing the trained variables after training.
saver = tf.train.Saver()

# Variable initializer op (must be run before any training step).
init = tf.global_variables_initializer()

# (Removed: an older, disabled copy of the training loop. It lived inside a
# no-op triple-quoted string, duplicated the __main__ block below with a
# different keep_prob (0.8) and no checkpoint saving, and was never executed.)
if __name__ == '__main__':
    # Initialize all variables before training.
    sess.run(tf.global_variables_initializer())

    for step in range(34000):
        # Fix: use the module-level batch_size instead of a hard-coded 50,
        # and unpack the batch into named variables instead of indexing.
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)

        if step % 100 == 0:
            # Report training-batch accuracy with dropout disabled (keep_prob=1.0).
            train_accuracy = accuracy.eval(feed_dict={
                x: batch_xs, y_: batch_ys, keep_prob: 1.0
            })
            print("step %d, training accuracy = %g" % (step, train_accuracy))

        # One optimization step; keep_prob=0.5 enables dropout during training.
        optimizer.run(feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})

    # Final evaluation on the held-out test set (dropout disabled).
    print("test accuracy = %g" % accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0
    }))

    # Fix: ensure the checkpoint directory exists — Saver.save raises if the
    # parent directory of the checkpoint path is missing.
    import os
    os.makedirs("model", exist_ok=True)
    save_path = saver.save(sess, "model/AlexNet.ckpt")
    print("model has been saved into %s" % save_path)